import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from Network.network import Network
from Network.net_types import network_type
from Network.network_utils import pytorch_model, get_acti
from Network.Dists.mask_utils import expand_mask, apply_symmetric
from Network.Dists.net_dist_utils import init_key_query, init_forward_args
import copy, time


class DiagGaussianForwardPadMaskNetwork(Network):
    '''
    Diagonal-Gaussian output network over key-query (object-factored) inputs.

    Handles object-based mask networks, including pair, asymmetric and symmetric
    factor networks. Inputs are encoded into keys and queries by a key-query
    encoder, then two parallel sub-networks of type ``network_type[args.net_type]``
    produce the mean (squashed with tanh) and the variance (sigmoid plus a
    ``base_variance`` floor). Because masking is often subnet dependent, this
    class does not apply the mask itself; it only merges the mask with the
    validity mask and hands it to the sub-networks.
    '''
    def __init__(self, args):
        '''
        Build the key-query encoder and the mean / std sub-networks from ``args``.

        Side effect: writes the encoder's ``key_dim`` / ``query_dim`` back into
        ``args.factor`` so that ``init_forward_args`` sees the encoded sizes.
        '''
        super().__init__(args)

        self.fp = args.factor
        self.use_valid = args.use_valid
        self.key_args, self.query_args, self.key_query_encoder = \
            init_key_query(args) # even if we don't embed, init a key_query that just slices
        # propagate the encoded dims so the forward sub-networks are sized correctly
        args.factor.key_dim, args.factor.query_dim = self.key_query_encoder.key_dim, self.key_query_encoder.query_dim
        # mean and std heads have no final activation; tanh / sigmoid are applied in forward
        mean_args = init_forward_args(args)
        mean_args.activation_final = "none"
        self.mean = network_type[args.net_type](mean_args)
        std_args = init_forward_args(args)
        std_args.activation_final = "none"
        self.std = network_type[args.net_type](std_args)
        self.model = [self.key_query_encoder, self.mean, self.std]

        # minimum variance added after the sigmoid, read from args.dist
        # (presumably chosen relative to normalized data; confirm with the config)
        self.base_variance = args.dist.base_variance

        self.pre_embed = args.factor.pre_embed
        self.object_dim = args.factor.object_dim
        self.embed_dim = args.embed_dim
        self.append_broadcast_mask = args.factor_net.append_broadcast_mask

        self.train()
        self.reset_network_parameters()

    def reset_environment(self, factor_params):
        '''
        Swap in new factor parameters and propagate them to every sub-module
        that supports ``reset_environment``.
        '''
        self.fp = factor_params
        self.key_query_encoder.reset_environment(factor_params)
        # inter_models is never assigned in this class (it may come from the base
        # class or a subclass), so look it up defensively instead of indexing [0]
        for im in getattr(self, "inter_models", None) or []:
            if hasattr(im, "reset_environment"):
                im.reset_environment(factor_params)
        for head in (self.mean, self.std):
            if hasattr(head, "reset_environment"):
                head.reset_environment(factor_params)

    def forward(self, x, m=None, valid =None, dist_settings=None, ret_settings=None, grad_settings=None):
        '''
        Compute the diagonal-Gaussian parameters for a batch.

        Args:
            x: batch of inputs; ``x.shape[0]`` is used as the batch size. When
                ``pre_embed`` is set, x is expected to already be a (keys, queries)
                tuple (NOTE(review): ``x.shape[0]`` is still read below in that
                case - confirm ``pytorch_model.wrap`` yields something with a shape).
            m: optional interaction mask, documented as
                [batch_size, num_keys, num_queries].
            valid: optional validity mask; multiplied into the mask when
                ``use_valid`` is enabled (or the attribute is missing).
            dist_settings: unused; kept for interface consistency with the
                mixture-of-experts model (soft, mixed, flat, full).
            ret_settings: forwarded to both sub-networks (embedding,
                reconstruction, weights, ...). If it contains "mask" and ``m`` is
                given, the returned mask is the elementwise max of the mean and
                variance sub-network masks.
            grad_settings: iterable of flags; "input" sets ``x.requires_grad``,
                "embed" sets ``requires_grad`` on keys and queries.

        Returns:
            ((mean, var), mask, (x, keys, queries, mean_extras, var_extras)) where
            mean = tanh(mean_net) and var = sigmoid(std_net) + base_variance, both
            reshaped to (batch_size, -1); the extras are the sub-networks' outputs
            after their first element.
        '''
        grad_settings = grad_settings if grad_settings is not None else []
        x = pytorch_model.wrap(x, cuda=self.iscuda)
        if "input" in grad_settings: x.requires_grad = True
        m = pytorch_model.wrap(m, cuda=self.iscuda) if m is not None else m
        valid = pytorch_model.wrap(valid, cuda=self.iscuda) if valid is not None else valid
        batch_size = x.shape[0]

        # if pre_embed, assumes that x is already a tuple of key, query
        keys, queries = x if self.pre_embed else self.key_query_encoder(x)
        # NOTE(review): assigning requires_grad raises in PyTorch if keys/queries are
        # non-leaf tensors (e.g. produced by the encoder while tracking gradients);
        # retain_grad() may be what is intended here - confirm before relying on it.
        if "embed" in grad_settings: keys.requires_grad, queries.requires_grad = True, True

        # perform necessary slicing and broadcasting operations; keys.shape[1] and
        # queries.shape[1] are passed as the key / query counts
        mv = self.key_query_encoder.slice_masks(m, batch_size, keys.shape[1], queries.shape[1])
        validv = self.key_query_encoder.slice_masks(valid, batch_size, keys.shape[1], queries.shape[1])
        # merge mask and valid; objects saved before use_valid existed default to using it
        if getattr(self, "use_valid", True) and valid is not None:
            mv = mv * validv if m is not None else validv
        mean = self.mean(keys, queries, mv, ret_settings=ret_settings)
        var = self.std(keys, queries, mv, ret_settings=ret_settings)
        # merges the mean and variance masks, if requested (guard against ret_settings=None)
        if m is not None and ret_settings is not None and "mask" in ret_settings:
            m = torch.stack([mean[1], var[1]], dim=-1).max(dim=-1)[0]
        meanv = torch.tanh(mean[0]).view(batch_size, -1)
        # sigmoid keeps the variance bounded; base_variance keeps it strictly positive
        varv = (torch.sigmoid(var[0]) + self.base_variance).view(batch_size, -1)
        return (meanv, varv), m, (x, keys, queries, mean[1:], var[1:])
        